from torch import optim


def get_optim(name):
    """Return the ``torch.optim`` optimizer class registered under *name*.

    Supported names are ``'SGD'``, ``'Adam'`` and ``'AdamW'``. Matching is
    exact and case-sensitive.

    Args:
        name: Identifier of the optimizer.

    Returns:
        The optimizer class (e.g. ``optim.SGD``), not an instance. The caller
        is responsible for constructing it with model parameters and
        hyper-parameters. Returns ``None`` for an unrecognized name, matching
        the previous fall-through behaviour.
    """
    # Map names to optimizer *classes*. The old 'SGD' branch returned the
    # `optim` module itself, which is not callable as an optimizer.
    optimizers = {
        'SGD': optim.SGD,
        'Adam': optim.Adam,
        'AdamW': optim.AdamW,
    }
    return optimizers.get(name)
